#' ---
#' title: "Regression and Other Stories: Bayesian $R^2$"
#' author: "Andrew Gelman, Jennifer Hill, Aki Vehtari"
#' date: "`r format(Sys.Date())`"
#' output:
#'   html_document:
#'     theme: readable
#'     toc: true
#'     toc_depth: 2
#'     toc_float: true
#'     code_download: true
#' ---

#' Bayesian $R^2$. See Chapter 11 in Regression and Other Stories.
#'
#' See also
#' - Andrew Gelman, Ben Goodrich, Jonah Gabry, and Aki Vehtari (2018).
#'   R-squared for Bayesian regression models. The American Statistician, 73:307-309
#'   [doi:10.1080/00031305.2018.1549100](https://doi.org/10.1080/00031305.2018.1549100).
#' 
#' -------------
#' 

#' ## Introduction
#'
#' Gelman, Goodrich, Gabry, and Vehtari (2018) define Bayesian $R^2$ as
#' $$
#' R^2 = \frac{\mathrm{Var}_{\mu}}{\mathrm{Var}_{\mu}+\mathrm{Var}_\mathrm{res}},
#' $$
#' where $\mathrm{Var}_{\mu}$ is the variance of the modelled predictive means,
#' and $\mathrm{Var}_\mathrm{res}$ is the modelled residual variance.
#' Both of these are computed using only posterior quantities from the
#' fitted model: the model-based $R^2$ is evaluated separately for each
#' posterior draw $s$ of the predictive means and the residual variance.
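#' For linear regression $\mu_n=X_n\beta$ we define, writing $V_{n=1}^N$
#' for the sample variance over the $N$ observations,
#' $$
#' \begin{aligned}
#' \mathrm{Var}_{\mu}^s &= V_{n=1}^N \mu_n^s,\\
#' \mathrm{Var}_\mathrm{res}^s &= (\sigma^2)^s,
#' \end{aligned}
#' $$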
#' and for logistic regression, following Tjur (2009),
#' we define $\mu_n=\pi_n$ and 
#' $$
#' \begin{aligned}
#' \mathrm{Var}_{\mu}^s &= V_{n=1}^N \mu_n^s,\\
#' \mathrm{Var}_\mathrm{res}^s &= \frac{1}{N}\sum_{n=1}^N (\pi_n^s(1-\pi_n^s)),
#' \end{aligned}
#' $$
#' where $\pi_n^s$ are predicted probabilities.
#' 

#+ setup, include=FALSE
# Default chunk options for the whole document
knitr::opts_chunk$set(
  message = FALSE,
  error = FALSE,
  warning = FALSE,
  comment = NA
)
# Set to TRUE to also write the figures to separate files
savefigs <- FALSE

#' #### Load packages
library("rprojroot")
root<-has_file(".ROS-Examples-root")$make_fix_file()
library("rstanarm")
library("ggplot2")
library("bayesplot")
theme_set(bayesplot::theme_default(base_family = "sans"))
library("foreign")
# for reproducibility
SEED <- 1800
set.seed(SEED)
#+ eval=FALSE, include=FALSE
# grayscale figures for the book
if (savefigs) color_scheme_set(scheme = "gray")

#' #### Function for Bayesian R-squared for stan_glm models. 
#' 
#' Bayes-R2 function using modelled (approximate) residual variance
# Posterior draws of the model-based Bayesian R^2.
#
# fit: a fitted rstanarm model. For a binomial family with a single-column
#   response the residual variance of each draw is mean(mu * (1 - mu));
#   otherwise the draws of the `sigma` parameter are squared.
# Returns var_mu / (var_mu + var_res), one value per row (posterior draw) of
# the predictive means.
bayes_R2 <- function(fit) {
  # Use the response stored in the fitted model. The branch below must not
  # depend on whichever `y` happens to exist in the calling environment.
  y <- rstanarm::get_y(fit)
  mupred <- rstanarm::posterior_linpred(fit, transform = TRUE)
  # variance across observations, computed separately for every draw
  var_mupred <- apply(mupred, 1, var)
  if (family(fit)$family == "binomial" && NCOL(y) == 1) {
    sigma2 <- apply(mupred * (1 - mupred), 1, mean)
  } else {
    sigma2 <- as.matrix(fit, pars = c("sigma"))^2
  }
  var_mupred / (var_mupred + sigma2)
}

#' ## Toy data with n=5
# Five observations; x and y are both shifted down by 3
x <- seq_len(5) - 3
y <- c(1.7, 2.6, 2.5, 4.4, 3.8) - 3
xy <- data.frame(x = x, y = y)

#' #### Lsq fit
fit <- lm(y ~ x, data = xy)
ols_coef <- coef(fit)
# fitted values and residuals rebuilt from the least-squares coefficients
yhat <- ols_coef[1] + ols_coef[2] * x
r <- y - yhat
# R^2 as var(fit) / var(y) and as var(fit) / (var(fit) + var(residuals))
rsq_1 <- var(yhat) / var(y)
rsq_2 <- var(yhat) / (var(yhat) + var(r))
round(c(rsq_1, rsq_2), 3)

#' #### Bayes fit
# Priors with fixed scales (autoscale = FALSE): intercept centred at 0 and
# slope centred at 1, both with scale 0.2
fit_bayes <- stan_glm(
  y ~ x,
  data = xy,
  prior_intercept = normal(0, 0.2, autoscale = FALSE),
  prior = normal(1, 0.2, autoscale = FALSE),
  prior_aux = NULL,
  seed = SEED,
  refresh = 0
)
# posterior draws of the intercept and slope, and their column means
posterior <- as.matrix(fit_bayes, pars = c("(Intercept)", "x"))
post_means <- colMeans(posterior)

#' #### Median Bayesian R^2
bayesR2 <- bayes_R2(fit_bayes)
round(median(bayesR2), 2)

#' #### Figures
#'
#' The first section of code below creates plots using base R graphics. <br>
#' Below that there is code to produce the plots using ggplot2.

# pick 20 rows (posterior draws) at random, without replacement
keep <- sample(nrow(posterior), size = 20)
samp_20_draws <- posterior[keep, ]

#' #### Base graphics version

#+ eval=FALSE, include=FALSE
if (savefigs) pdf("fig/rsquared1a.pdf", height=4, width=5)
#+
# Data with the least-squares line, the prior regression line and the
# posterior mean line; the y axis reuses the range of x.
par(mar=c(3,3,1,1), mgp=c(1.7,.5,0), tck=-.01)
plot(
  x, y,
  ylim = range(x),
  xlab = "x",
  ylab = "y",
  main = "Least squares and Bayes fits",
  bty = "l",
  pch = 20
)
# least-squares line
abline(coef(fit)[1], coef(fit)[2], col = "black")
text(-1.6,-.7, "Least-squares\nfit", cex = .9)
# prior regression line: intercept 0 and slope 1, the prior locations of the fit
abline(0, 1, col = "blue", lty = 2)
text(-1, -1.8, "(Prior regression line)", col = "blue", cex = .9)
# Draw the line from the posterior means so that it matches its
# "Posterior mean fit" label and the ggplot version of this figure;
# coef() on an rstanarm fit reports posterior medians instead.
abline(post_means[1], post_means[2], col = "blue")
text(1.4, 1.2, "Posterior mean fit", col = "blue", cex = .9)
# posterior mean predictions at the observed x values
points(
  x,
  post_means[1] + post_means[2] * x,
  pch = 20,
  col = "blue"
)
#+ eval=FALSE, include=FALSE
if (savefigs) dev.off()

#+ eval=FALSE, include=FALSE
if (savefigs) pdf("fig/rsquared1b.pdf", height=4, width=5)
#+
# Data with 20 sampled posterior regression lines and the posterior mean line;
# the y axis reuses the range of x.
par(mar=c(3,3,1,1), mgp=c(1.7,.5,0), tck=-.01)
plot(
  x, y,
  ylim = range(x),
  xlab = "x",
  ylab = "y",
  bty = "l",
  pch = 20,
  main = "Bayes posterior simulations"
)
# one light line per sampled draw (column 1 = intercept, column 2 = slope);
# seq_len() keeps the loop empty if there are no rows
for (s in seq_len(nrow(samp_20_draws))) {
  abline(samp_20_draws[s, 1], samp_20_draws[s, 2], col = "#9497eb")
}
# Posterior mean line, as in the ggplot version of this figure; coef() on an
# rstanarm fit would give the posterior medians instead.
abline(
  post_means[1],
  post_means[2],
  col = "#1c35c4",
  lwd = 2
)
# redraw the data on top of the lines
points(x, y, pch = 20, col = "black")
#+ eval=FALSE, include=FALSE
if (savefigs) dev.off()

#' #### ggplot version
# Update the active ggplot2 theme: bold, centred plot titles and axis text at
# 1.1 times its relative size
theme_update(
  plot.title = element_text(face = "bold", hjust = 0.5), 
  axis.text = element_text(size = rel(1.1))
)
# ggplot counterpart of the first base graphics figure: the data, the
# least-squares line, the prior line and the posterior mean line
fig_1a <-
  ggplot(xy, aes(x, y)) +
  geom_point() +
  geom_abline( # ols regression line
    intercept = ols_coef[1],
    slope = ols_coef[2],
    size = 0.4
  ) +
  geom_abline( # prior regression line
    intercept = 0,
    slope = 1,
    color = "blue",
    linetype = 2,
    size = 0.3
  ) +
  geom_abline( # posterior mean regression line
    intercept = post_means[1],
    slope = post_means[2],
    color = "blue",
    size = 0.4
  ) +
  # posterior mean predictions at the observed x values
  geom_point(
    aes(y = post_means[1] + post_means[2] * x),
    color = "blue"
  ) +
  # text labels for the three lines, placed at the same coordinates as in the
  # base graphics version
  annotate(
    geom = "text",
    x = c(-1.6, -1, 1.4),
    y = c(-0.7, -1.8, 1.2),
    label = c(
      "Least-squares\nfit",
      "(Prior regression line)",
      "Posterior mean fit"
    ),
    color = c("black", "blue", "blue"),
    size = 3.8
  ) +
  # the y axis reuses the range of x
  ylim(range(x)) + 
  ggtitle("Least squares and Bayes fits")
plot(fig_1a)
#+ eval=FALSE, include=FALSE
# no plot argument is given, so ggsave() falls back to its default plot
if (savefigs) ggsave("fig/rsquared1a-gg.pdf", width = 5, height = 4)
#+
# ggplot counterpart of the second base graphics figure: sampled posterior
# regression lines, the posterior mean line, and the data drawn on top
fig_1b <-
  ggplot(xy, aes(x, y)) +
  geom_abline( # 20 posterior draws of the regression line
    intercept = samp_20_draws[, 1],
    slope = samp_20_draws[, 2],
    color = "#9497eb",
    size = 0.25
  ) +
  geom_abline( # posterior mean regression line
    intercept = post_means[1],
    slope = post_means[2],
    color = "#1c35c4",
    size = 1
  ) +
  # points are added after the lines so that they are drawn over them
  geom_point() +
  # the y axis reuses the range of x
  ylim(range(x)) + 
  ggtitle("Bayes posterior simulations")
plot(fig_1b)
#+ eval=FALSE, include=FALSE
# no plot argument is given, so ggsave() falls back to its default plot
if (savefigs) ggsave("fig/rsquared1b-gg.pdf", width = 5, height = 4)

#' #### Bayesian R^2 Posterior and median
# histogram of the R^2 draws with a vertical line at their median
mcmc_hist(data.frame(bayesR2), binwidth = 0.02) +
  xlim(c(0, 1)) +
  scale_y_continuous(breaks = NULL) +
  xlab("Bayesian R2") +
  geom_vline(xintercept = median(bayesR2)) +
  ggtitle("Bayesian R squared posterior and median")
#+ eval=FALSE, include=FALSE
if (savefigs) {
  ggsave("fig/bayesr2post.pdf", width = 5, height = 4)
}

#' ## Toy logistic regression example, n=20
set.seed(20)
# 20 binary outcomes whose success probability increases with the index
y <- rbinom(n = 20, size = 1, prob = (1:20 - 0.5) / 20)
data <- data.frame(rvote = y, income = 1:20)
fit_logit <- stan_glm(
  rvote ~ income,
  family = binomial(link = "logit"),
  data = data,
  refresh = 0
)

#' #### Median Bayesian R^2
bayesR2 <- bayes_R2(fit_logit)
round(median(bayesR2), 2)

#' #### Plot posterior of Bayesian R^2
#+ message=FALSE, error=FALSE, warning=FALSE
pxl <- xlim(0, 1)
mcmc_hist(data.frame(bayesR2), binwidth = 0.02) +
  pxl +
  scale_y_continuous(breaks = NULL) +
  xlab("Bayesian R2") +
  geom_vline(xintercept = median(bayesR2))

#' ## Mesquite - linear regression
#' 
#' Predicting the yields of mesquite bushes, n=46
#' 

#' #### Load data
mesquite <- read.table(root("Mesquite/data", "mesquite.dat"), header = TRUE)
# derived canopy measures built from the two diameters and the canopy height
mesquite$canopy_volume <- with(mesquite, diam1 * diam2 * canopy_height)
mesquite$canopy_area <- with(mesquite, diam1 * diam2)
mesquite$canopy_shape <- with(mesquite, diam1 / diam2)
n <- nrow(mesquite)
n

#' #### Predict log weight model with log canopy volume, log canopy shape, and group
fit_5 <- stan_glm(
  log(weight) ~ log(canopy_volume) + log(canopy_shape) + group,
  data = mesquite,
  refresh = 0
)

#' #### Median Bayesian R2
bayesR2 <- bayes_R2(fit_5)
round(median(bayesR2), 2)

#' #### Plot posterior of Bayesian R2
#+ message=FALSE, error=FALSE, warning=FALSE
pxl <- xlim(0.6, 0.95)
mcmc_hist(data.frame(bayesR2), binwidth = 0.01) +
  pxl +
  scale_y_continuous(breaks = NULL) +
  xlab("Bayesian R2") +
  geom_vline(xintercept = median(bayesR2))

#' ## LowBwt -- logistic regression
#' 
#' Predict low birth weight, n=189, from Hosmer et al. (2000).
#' This data was also used by Tjur (2009)
#' 

#' #### Load data
lowbwt <- read.table(root("LowBwt/data", "lowbwt.dat"), header = TRUE)
n <- nrow(lowbwt)
n
head(lowbwt)

#' #### Predict low birth weight
fit_1 <- stan_glm(
  low ~ age + lwt + factor(race) + smoke,
  family = binomial(link = "logit"),
  data = lowbwt,
  refresh = 0
)

#' #### Median Bayesian R2
bayesR2 <- bayes_R2(fit_1)
round(median(bayesR2), 2)

#' #### Plot posterior of Bayesian R2
#+ message=FALSE, error=FALSE, warning=FALSE
pxl <- xlim(0, 0.3)
mcmc_hist(data.frame(bayesR2), binwidth = 0.01) +
  pxl +
  scale_y_continuous(breaks = NULL) +
  xlab("Bayesian R2") +
  geom_vline(xintercept = median(bayesR2))

#' ## LowBwt -- linear regression
#' 
#' Predict birth weight, n=189, from Hosmer et al. (2000).
#' Tjur (2009) used logistic regression for dichotomized birth weight.
#' Below we use the continuous-valued birth weight.
#' 

#' #### Predict birth weight
fit_2 <- stan_glm(
  bwt ~ age + lwt + factor(race) + smoke,
  data = lowbwt,
  refresh = 0
)

#' #### Median Bayesian R2
bayesR2 <- bayes_R2(fit_2)
round(median(bayesR2), 2)

#' #### Plot posterior of Bayesian R2
#+ message=FALSE, error=FALSE, warning=FALSE
pxl <- xlim(0, 0.36)
mcmc_hist(data.frame(bayesR2), binwidth = 0.01) +
  pxl +
  scale_y_continuous(breaks = NULL) +
  xlab("Bayesian R2") +
  geom_vline(xintercept = median(bayesR2))

#' ## KidIQ - linear regression
#' 
#' Children's test scores data, n=434
#' 

#' #### Load children's test scores data
kidiq <- read.csv(root("KidIQ/data", "kidiq.csv"))
n <- nrow(kidiq)
n
head(kidiq)

#' #### Predict test score
fit_3 <- stan_glm(
  kid_score ~ mom_hs + mom_iq,
  data = kidiq,
  seed = SEED,
  refresh = 0
)

#' #### Median Bayesian R2
bayesR2 <- bayes_R2(fit_3)
round(median(bayesR2), 2)

#' #### Plot posterior of Bayesian R2
#+ message=FALSE, error=FALSE, warning=FALSE
pxl <- xlim(0.05, 0.35)
mcmc_hist(data.frame(bayesR2), binwidth = 0.01) +
  pxl +
  scale_y_continuous(breaks = NULL) +
  xlab("Bayesian R2") +
  geom_vline(xintercept = median(bayesR2))


#' #### Add five pure noise predictors to the data
set.seed(1507)
n <- nrow(kidiq)
kidiqr <- kidiq
# the five noise predictors are stored together as one n x 5 matrix column
kidiqr$noise <- array(rnorm(5 * n), dim = c(n, 5))

#' #### Linear regression with additional noise predictors
fit_3n <- stan_glm(
  kid_score ~ mom_hs + mom_iq + noise,
  data = kidiqr,
  seed = SEED,
  refresh = 0
)
print(fit_3n)

#' #### Median Bayesian R2
bayesR2n <- bayes_R2(fit_3n)
round(median(bayesR2n), 2)

#' Median Bayesian R2 is higher with additional noise predictors, but
#' the distribution of Bayesian R2 reveals that the increase is not
#' practically relevant.
#' 

#' #### Plot posterior of Bayesian R2
#+ message=FALSE, error=FALSE, warning=FALSE
# Same x range as the histogram for the model without noise predictors, so
# the two posteriors can be compared directly
pxl<-xlim(0.05, 0.35)
# The vertical line marks the median of the plotted draws (bayesR2n), not the
# median of the model without noise predictors (bayesR2)
mcmc_hist(data.frame(bayesR2n), binwidth=0.01) + pxl +
    scale_y_continuous(breaks=NULL) +
    xlab('Bayesian R2') +
    geom_vline(xintercept=median(bayesR2n))

#' ## Earnings - logistic and linear regression
#' 
#' Predict respondents' yearly earnings using survey data from 1990.</br>
#' logistic regression n=1374, linear regression n=1187
#' 

#' #### Load data
earnings <- read.csv(root("Earnings/data", "earnings.csv"))
n <- nrow(earnings)
n
head(earnings)

#' #### Bayesian logistic regression on non-zero earnings</br>
#' Predict using height and sex
fit_1a <- stan_glm(
  (earn>0) ~ height + male,
  family = binomial(link = "logit"),
  data = earnings,
  refresh = 0
)

#' #### Median Bayesian R2
bayesR2 <- bayes_R2(fit_1a)
round(median(bayesR2), 3)

#' #### Plot posterior of Bayesian R2
#+ message=FALSE, error=FALSE, warning=FALSE
pxl <- xlim(0.02, 0.11)
mcmc_hist(data.frame(bayesR2), binwidth = 0.002) +
  pxl +
  scale_y_continuous(breaks = NULL) +
  xlab("Bayesian R2") +
  geom_vline(xintercept = median(bayesR2))

#' #### Bayesian probit regression on non-zero earnings</br>
fit_1p <- stan_glm(
  (earn>0) ~ height + male,
  family = binomial(link = "probit"),
  data = earnings,
  refresh = 0
)

#' #### Median Bayesian R2
# bayesR2 still holds the draws from the logistic model fitted above
round(median(bayesR2), 3)
bayesR2p <- bayes_R2(fit_1p)
round(median(bayesR2p), 3)

#' #### Compare logistic and probit models using new Bayesian R2
#+ message=FALSE, error=FALSE, warning=FALSE
pxl <- xlim(0.02, 0.11)
# logistic model
p1 <- mcmc_hist(data.frame(bayesR2), binwidth = 0.002) +
  pxl +
  scale_y_continuous(breaks = NULL) +
  ggtitle("Earnings data with n=1374") +
  xlab("Bayesian R2 for logistic model") +
  geom_vline(xintercept = median(bayesR2))
# probit model, drawn on the same x range
p2 <- mcmc_hist(data.frame(bayesR2p), binwidth = 0.002) +
  pxl +
  scale_y_continuous(breaks = NULL) +
  xlab("Bayesian R2 for probit model") +
  geom_vline(xintercept = median(bayesR2p))
bayesplot_grid(p1, p2)

#' There is no practical difference in predictive performance between
#' logit and probit.
#' 

#' #### Bayesian model on positive earnings on log scale
# only respondents with positive earnings enter the fit (subset = earn > 0)
fit_1b <- stan_glm(
  log(earn) ~ height + male,
  data = earnings,
  subset = earn > 0,
  refresh = 0
)

#' #### Median Bayesian R2
bayesR2 <- bayes_R2(fit_1b)
round(median(bayesR2), 3)

#' #### Plot posterior of Bayesian R2
#+ message=FALSE, error=FALSE, warning=FALSE
pxl <- xlim(0.02, 0.15)
mcmc_hist(data.frame(bayesR2), binwidth = 0.002) +
  pxl +
  scale_y_continuous(breaks = NULL) +
  xlab("Bayesian R2") +
  geom_vline(xintercept = median(bayesR2))
